-
Notifications
You must be signed in to change notification settings - Fork 639
[SOT][CUDAGraph] Add support for custom all-reduce operators under SOT mode #4386
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SOT][CUDAGraph] Add support for custom all-reduce operators under SOT mode #4386
Conversation
Thanks for your contribution! |
entry.captured = True | ||
with self.cuda_graph_manager.run_impl_guard(): | ||
entry.runnable(**kwargs) | ||
with capture_custom_allreduce(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
静态图也用custom all reduce 对吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对,目前静态图 Custom AllReduce 和 Paddle AllReduce 都支持
但是使用Custom AllReduce,需要加参数 --max-num-batched-tokens 500
,后面这个数小于500就行,具体原因后面继续排查~
self.proposer.update_task_chunk_prefill(task) | ||
task.chunk_idx += 1 | ||
|
||
@sot_warmup_guard(True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SOT的 Warm Up 延后是为了避免 custom all reduce 的什么问题呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在 custom_all_reduce
中,只有在 Capture 的时候才会走第一个 if self.capturing:
分支
但是,之前 SOT Warmup 在 Capture Model 前面,也就是IR图在 CUDAGraph Capture 之前就已经确定了,走 else
分支,这就导致 replay 的时候也走 else
分支,这是不对的
FastDeploy/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py
Lines 208 to 222 in 5abf597
def custom_all_reduce(self, input: paddle.Tensor) -> Optional[paddle.Tensor]: | |
"""The main allreduce API that provides support for cuda graph.""" | |
if self.capturing: | |
lib = cuda_wrapper.CudaRTLibrary() | |
stream = paddle.device.current_stream() | |
stream_capturing = lib.cudaStreamIsCapturing(stream) | |
if stream_capturing.value == 1: | |
# 1 is cudaStreamCaptureStatusActive: The stream is capturing. | |
return self.all_reduce(input, input, registered=True) | |
else: | |
# If warm up, mimic the allocation pattern since custom | |
# allreduce is out-of-place. | |
return paddle.empty_like(input) | |
else: | |
return self.all_reduce(input, input, registered=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Custom Allreduce 可以都增加到8192 * 1024 * 32 * 2 |
本 PR 的主要操作是将SOT warmup的过程延后,在CUDAGraph的Capture阶段进行warmup
本PR的前置PR
这个PR终于跑通 SOT + CUDAGraph + 开启子图切分的整个流程,在这个PR中梳理一下整个过程:
单卡模型
以下是 @zyfncg 的前置PR,SOT下实现CudaGraph子图捕获功能,我们快速跑通了 ERNIE45T 0.3B✅
但是依旧存在显存拷贝的问题,如果输入的位置发生改变,则将新输入Copy到Capture的位置
为了解决这个问题,#3302 添加了 append_attention_with_output,在运行 append_attention 之前,优先创建一个 empty Tensor 作为 append_attention 的输出——在外部管理 append_attention 的显存
#3694 修复了 #3302 导致的打断,#4340(是 #3694 的一部分)移除了一些无用的输出,避免动静不统一导致的BUG
#3694 依赖 Paddle 主框架的两个PR:
memcpy
&& Add CUDAGraph unitest Paddle#75078到此为止,单卡的 ERNIE45T 21B和0.3B 都能跑通✅,且不存在 Copy,但是多卡运行会遇到CUDA700的问题
多卡模型
遇到到第一个CUDA700问题,不开CUDAGraph,只开SOT就能复现
用 cuda-gdb 分析这个CUDA700的问题,定位到是 Custom Allreduce 的问题,可暂时通过
--max-num-batched-tokens 2000
规避与 @zhink 沟通后,可通过调大 Custom Allreduce 的 Buffer 大小来规避,即:
(后面发现这个增大 Buffer 的策略似乎无效了)
此时SOT + Custom Allreduce 可以跑通推理流程 ✅
但开启 CUDAGraph 遇到 mp_allreduce_sum 对应的 DeviceContext 存在 cudagraph allocator 为空指针的问题
我把所有 Instruction 对应的 DeviceContext 指针打印了出来,统计了一下,共有1个
CPUContext
+2个GPUContext
出现次数分别是几十次、几千次和几万次,而动态图+CUDAGraph只有1个
CPUContext
+1个GPUContext
定位到是
phi::DeviceContext* ParseDeviceContext
这个函数将GPUContext
(有cudagraph allocator)转化成了另一个GPUContext
(无cudagraph allocator),这里为了先跑通,就先直接return origin_dev_ctx;
了,(后续PR: PaddlePaddle#75954)但依旧会有 CUDA700 的问题中间其实也尝试了很多其他方法,和老代码 battle 了好久,遇到了多 CUDA Stream 的问题,CUDA90X之类的,这时,@zyfncg 说:
好,那就先解这个CUDA700,悲催的是,用之前的方法:cuda-gdb 分析,会出现连环 core dump 的问题,在生成 coredump 文件中,又报一个 CUDA700 🤦♂️
@zyfncg 又问了一个关键的问题:目前的问题是 Capture 阶段还是 Replay 阶段?我们把 Capture 过程中的 Replay 全都注释掉看看
把 Capture 过程中的 Replay 全都注释掉后,服务就能启动了,这其实说明是Replay阶段报的错,Capture阶段没问题
此时忽然想到是不是还是CustomAllreduce导致的问题,于是加上
--disable-custom-all-reduce
跑一下,没想到直接跑通了也就是 SOT+CUDAGraph+子图切分(禁用 CustomAllReduce)可以跑通整个流程✅
另外也可以确认,目前CUDA700的问题还是 CustomAllReduce 的问题,接下来就从动静统一的角度(动态图能跑通CUDAGraph,但SOT跑不通)来考虑差异在哪
最早怀疑的是,一些初始化操作(如下)静态图没有和动态图对齐
FastDeploy/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py
Lines 100 to 101 in 5abf597
但在SOT模式中,
__init__
函数还是以动态图跑的,没啥区别,也看了下函数实现,就是初始化 buffer + 创建一个 C++侧的CustomAllreduce 对象,并返回其指针到Python端其次怀疑的是,if 分支 这里的
should_custom_ar
函数:FastDeploy/fastdeploy/distributed/communication.py
Lines 53 to 71 in 5abf597
这个函数内选用 paddle.all_reduce 还是 custom_all_reduce,问题在这里吗?那先看看这个
should_custom_ar
:FastDeploy/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py
Lines 134 to 145 in 5abf597
刚好这个函数的调用在
@paddle.jit.marker.unified
这个装饰器之下,内部可以直接print,不用担心打断由于静态图中存在-1动态维度的情况,可以直接打印 inp.shape[0], inp.shape[1] 的值,只有 inp.shape[0] 存在为 -1 ,而
inp.shape[1] * inp.element_size()
刚好可以被16整除,所以inp_size % 16 == 0
,无需担心这个if分支动静不统一那
self.capturing
这个if 分支呢?之前没怎么注意过,在custom_all_reduce
中,只有在 Capture 的时候才会走第一个if self.capturing:
分支。但是,之前 SOT Warmup 在 Capture Model 前面,也就是IR图在 CUDAGraph Capture 之前就已经确定了,走else
分支,这就导致 replay 的时候也走else
分支,这是不对的。这里需要给SOT的执行过程也加上一个上下文管理器
capture_custom_allreduce
才能用这个self.capturing
,也就是如下FastDeploy/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
Lines 136 to 138 in 1e59905
并在
capture_model
函数头上加@sot_warmup_guard(True)
这个sot warmup的装饰器:https://github.com/cattidea/FastDeploy/blob/1b9f351d219013f1a69db355736ee33bbc035866/fastdeploy/worker/gpu_model_runner.py#L1672-L1673
也就是这个 PR 中的内容,也就是 SOT+CUDAGraph+子图切分(不禁用 CustomAllReduce)整个流程可以跑通✅
遗留问题
多卡不开CUDAGraph,只开SOT会有这个问题:
启动参数加上
--max-num-batched-tokens 4000
就好了,但改成8000
会有问题,动态图则没有这个限制,还是 CustomAllReduce 导致的cc @SigureMo @zyfncg